Vision Transformer
https://scrapbox.io/files/64f3004b6bc192001b3f1662.png
構造図
入力画像をpatch単位に分割
input layerからの出力は、class tokenおよび各パッチに対応するベクトルになっている。
CLSトークン = class tokenを指す
CLSトークンと各パッチのベクトルをTransformer Encoderに入力し、特徴抽出処理
Encoder内部では、Encoder Blockが積み重なっている。
Encoder Block内部では、Self-attention層とMLP層がある。
https://scrapbox.io/files/64f303e0ac909c001c4364a3.png
元の画像を$ \bm{x}\in\mathbb{R}^{H\times\ W \times C}から、flattenして入力するため、実際の入力の形は
$ \bm{x}_p\in\mathbb{R}^{N\times(P^2\cdot C)}である。$ Nはパッチ数、$ Pはパッチの大きさで、正方形のパッチである。
すなわち、$ N=\frac{HW}{P^2}を満たす。
各パッチの画像を横並べしたやつが$ N本あると思えば良い
https://scrapbox.io/files/64f30411702e7e001b8d0dd0.png
ViT Encoderの入力は、Patchベクトルたちを更に埋め込んだものである。
この埋め込みには線形変換を用いる。すなわち、この埋め込みに使用するベクトルを$ \bm{E}\in\mathbb{R}^{(P^2\cdot C)\times D}とし、Positional Encoding$ \bm{E}_{pos}\in\mathbb{R}^{(N+1)\times D}を使用する。clsトークンも含めてpositional encodingを行う。したがって、Embedded patches$ \bm{z}_0=\lbrack\bm{x}_{cls};\bm{x}_p^1\bm{E};\bm{x}_p^2\bm{E};...;\bm{x}_p^N\bm{E}\rbrack+\bm{E}_{pos}
上付きインデックスは各パッチを表す。$ \bm{z}_0\in\mathbb{R}^{(N+1)\times D}であることに注意
(各パッチの次元$ Dの潜在表現が$ N本とclsトークンの潜在表現からなる。)
https://scrapbox.io/files/64f3072a8bf00c001c66e592.jpg
さて、実装の話になると、重要な点として、以下はこれと等価なコードである。
$ \lbrack\bm{x}_p^1\bm{E};\bm{x}_p^2\bm{E};...;\bm{x}_p^N\bm{E}\rbrack
code:python
self.patch_emb_layer = nn.Conv2d(in_channels, out_channels,kernel_size=p,stride=p)
ViTBlockからの出力は、$ \lbrack B,N,D\rbrackになっているはず。